{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d95da963",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import os\n",
    "\n",
    "# Ignore bitsandbytes (bnb) warnings.\n",
    "os.environ[\"BITSANDBYTES_NOWELCOME\"] = \"1\"\n",
    "\n",
    "# Remove RANK_TABLE_FILE when it is set; pop() with a default does nothing otherwise.\n",
    "# NOTE(review): presumably this keeps an inherited distributed rank table from\n",
    "# affecting this single-process run -- confirm.\n",
    "os.environ.pop(\"RANK_TABLE_FILE\", None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "060725ca",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/numpy/core/getlimits.py:499: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.\n",
      "  setattr(self, word, getattr(machar, word).flat[0])\n",
      "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.\n",
      "  return self._float_to_str(self.smallest_subnormal)\n",
      "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/numpy/core/getlimits.py:499: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.\n",
      "  setattr(self, word, getattr(machar, word).flat[0])\n",
      "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.\n",
      "  return self._float_to_str(self.smallest_subnormal)\n",
      "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache /tmp/jieba.cache\n",
      "Loading model cost 1.289 seconds.\n",
      "Prefix dict has been built successfully.\n",
      "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/Cython/Compiler/Main.py:384: FutureWarning: Cython directive 'language_level' not set, using '3str' for now (Py3). This has changed from earlier releases! File: /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindnlp/transformers/models/graphormer/algos_graphormer.pyx\n",
      "  tree = Parsing.p_module(s, pxd, full_module_name)\n"
     ]
    }
   ],
   "source": [
    "import mindspore as ms\n",
    "from mindspore.common.api import _no_grad\n",
    "from mindnlp.core import nn, ops\n",
    "from mindnlp import peft\n",
    "import mindnlp.core.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6c13a889",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[MS_ALLOC_CONF]Runtime config:  enable_vmm:True  vmm_align_size:2MB\n"
     ]
    }
   ],
   "source": [
    "ms.manual_seed(0)\n",
    "X = ops.rand((1000, 20))\n",
    "y = (X.sum(1) > 10).long()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "621280fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The first n_train samples are used for training; the remainder is held out for evaluation.\n",
    "n_train = 800\n",
    "# Number of samples per batch when the datasets are batched.\n",
    "batch_size = 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "191b2059",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _make_batched_dataset(features, labels, shuffle):\n",
    "    \"\"\"Build a batched dataset with `input_ids` / `labels` columns from two tensors.\n",
    "\n",
    "    Args:\n",
    "        features: tensor of model inputs, converted with `.asnumpy()`.\n",
    "        labels: tensor of targets, converted with `.asnumpy()`.\n",
    "        shuffle: forwarded to `NumpySlicesDataset`.\n",
    "\n",
    "    Returns:\n",
    "        The dataset after `.batch(batch_size)`.\n",
    "    \"\"\"\n",
    "    arrays = (features.asnumpy(), labels.asnumpy())\n",
    "    dataset = ms.dataset.NumpySlicesDataset(arrays, column_names=[\"input_ids\", \"labels\"], shuffle=shuffle)\n",
    "    return dataset.batch(batch_size)\n",
    "\n",
    "\n",
    "# Training split: the first n_train samples, built with shuffle=True.\n",
    "X_train = ms.Tensor(X[:n_train])\n",
    "y_train = ms.Tensor(y[:n_train])\n",
    "train_dataset = _make_batched_dataset(X_train, y_train, shuffle=True)\n",
    "\n",
    "# Evaluation split: the remaining samples, built with shuffle=False.\n",
    "X_eval = ms.Tensor(X[n_train:])\n",
    "y_eval = ms.Tensor(y[n_train:])\n",
    "eval_dataset = _make_batched_dataset(X_eval, y_eval, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "57b42a21",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'input_ids': Tensor(shape=[64, 20], dtype=Float32, value=\n",
      "[[ 1.28077745e-01,  8.90942335e-01,  9.65436578e-01 ...  4.96876001e-01,  4.62052822e-01,  2.97103524e-01],\n",
      " [ 3.92949581e-01,  7.07879424e-01,  5.22751093e-01 ...  6.98211193e-02,  1.04892015e-01,  7.37131834e-01],\n",
      " [ 1.51267052e-01,  9.96178508e-01,  8.00368190e-01 ...  3.41349125e-01,  9.46143627e-01,  2.13831663e-01],\n",
      " ...\n",
      " [ 6.21959090e-01,  6.48394823e-01,  1.57515168e-01 ...  7.37506390e-01,  1.93000555e-01,  6.89157248e-01],\n",
      " [ 6.25082135e-01,  2.66415000e-01,  8.00211191e-01 ...  9.16727424e-01,  1.47469044e-01,  2.63221622e-01],\n",
      " [ 7.22806692e-01,  5.56603670e-01,  1.18087173e-01 ...  3.87998343e-01,  9.16841507e-01,  9.92253542e-01]]), 'labels': Tensor(shape=[64], dtype=Int64, value= [1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, \n",
      " 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, \n",
      " 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1])}\n"
     ]
    }
   ],
   "source": [
    "print(next(train_dataset.create_dict_iterator()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a028857c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    \"\"\"Perceptron with two hidden layers: 20 input features in, 2 LogSoftmax outputs out.\n",
    "\n",
    "    Args:\n",
    "        num_units_hidden: width of both hidden Linear layers.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_units_hidden=2000):\n",
    "        super().__init__()\n",
    "        # The position in this list determines each sub-module's name (seq.0 ... seq.5).\n",
    "        layers = [\n",
    "            nn.Linear(20, num_units_hidden),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(num_units_hidden, num_units_hidden),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(num_units_hidden, 2),\n",
    "            nn.LogSoftmax(dim=-1),\n",
    "        ]\n",
    "        self.seq = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, X):\n",
    "        \"\"\"Run `X` through the layer stack and return the LogSoftmax output.\"\"\"\n",
    "        log_probs = self.seq(X)\n",
    "        return log_probs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b8809390",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimisation hyper-parameters.\n",
    "# batch_size is defined earlier, before the datasets are batched. It is not repeated\n",
    "# here: the datasets already exist at this point, so a second value would never be used.\n",
    "lr = 0.002\n",
    "num_epochs = 30"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "074b857d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('', __main__.MLP),\n",
       " ('seq', mindnlp.core.nn.modules.container.Sequential),\n",
       " ('seq.0', mindnlp.core.nn.modules.linear.Linear),\n",
       " ('seq.1', mindnlp.core.nn.modules.activation.ReLU),\n",
       " ('seq.2', mindnlp.core.nn.modules.linear.Linear),\n",
       " ('seq.3', mindnlp.core.nn.modules.activation.ReLU),\n",
       " ('seq.4', mindnlp.core.nn.modules.linear.Linear),\n",
       " ('seq.5', mindnlp.core.nn.modules.activation.LogSoftmax)]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[(n, type(m)) for n, m in MLP().named_modules()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "19fda38d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LoRA setup for the MLP:\n",
    "#   r               - rank of the lora_A / lora_B matrices\n",
    "#   target_modules  - the two hidden Linear layers that receive LoRA adapters\n",
    "#   modules_to_save - the output Linear layer, kept as a separate trainable copy\n",
    "config = peft.LoraConfig(r=8, target_modules=[\"seq.0\", \"seq.2\"], modules_to_save=[\"seq.4\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "4fc7e57c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "seq.0.weight: (2000, 20)\n",
      "seq.0.bias: (2000,)\n",
      "seq.2.weight: (2000, 2000)\n",
      "seq.2.bias: (2000,)\n",
      "seq.4.weight: (2, 2000)\n",
      "seq.4.bias: (2,)\n"
     ]
    }
   ],
   "source": [
    "from mindnlp.core import optim\n",
    "from mindnlp.common.optimization import get_linear_schedule_with_warmup\n",
    "\n",
    "module = MLP()\n",
    "for name, param in module.named_parameters():\n",
    "    print(f\"{name}: {param.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "53e60749",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "base_model.model.seq.0.base_layer.weight: (2000, 20)\n",
      "base_model.model.seq.0.base_layer.bias: (2000,)\n",
      "base_model.model.seq.0.lora_A.default.weight: (8, 20)\n",
      "base_model.model.seq.0.lora_B.default.weight: (2000, 8)\n",
      "base_model.model.seq.2.base_layer.weight: (2000, 2000)\n",
      "base_model.model.seq.2.base_layer.bias: (2000,)\n",
      "base_model.model.seq.2.lora_A.default.weight: (8, 2000)\n",
      "base_model.model.seq.2.lora_B.default.weight: (2000, 8)\n",
      "base_model.model.seq.4.original_cell.weight: (2, 2000)\n",
      "base_model.model.seq.4.original_cell.bias: (2,)\n",
      "base_model.model.seq.4.modules_to_save.default.weight: (2, 2000)\n",
      "base_model.model.seq.4.modules_to_save.default.bias: (2,)\n"
     ]
    }
   ],
   "source": [
    "model = peft.get_peft_model(module, config)\n",
    "for name, param in model.named_parameters():\n",
    "    print(f\"{name}: {param.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b55d27db",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Adam (\n",
      "Parameter Group 0\n",
      "    amsgrad: False\n",
      "    betas: (0.9, 0.999)\n",
      "    eps: 1e-08\n",
      "    lr: 0.002\n",
      "    maximize: False\n",
      "    weight_decay: 0\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "optimizer = optim.Adam([param for name, param in model.named_parameters() if \"lora\" in name], lr=lr)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "print(optimizer)\n",
    "lr_scheduler = get_linear_schedule_with_warmup(\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=0,\n",
    "    num_training_steps=(len(train_dataset) * num_epochs),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "22655c8e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 52,162 || all params: 4,100,164 || trainable%: 1.2721930147184357\n"
     ]
    }
   ],
   "source": [
    "model.print_trainable_parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cfbbe7e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:02<00:00,  5.06it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 113.57it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=0   train_loss_total=0.6654  eval_loss_total=0.7004\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 32.63it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 184.01it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=1   train_loss_total=0.5330  eval_loss_total=0.5407\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 33.84it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 183.12it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=2   train_loss_total=0.3814  eval_loss_total=0.2996\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 32.22it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 164.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=3   train_loss_total=0.2496  eval_loss_total=0.2209\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 32.67it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 158.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=4   train_loss_total=0.1907  eval_loss_total=0.1977\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 33.03it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 190.72it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=5   train_loss_total=0.1369  eval_loss_total=0.2341\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 33.51it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 217.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=6   train_loss_total=0.1331  eval_loss_total=0.3139\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 32.86it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 199.87it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=7   train_loss_total=0.1187  eval_loss_total=0.1651\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 32.85it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 171.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=8   train_loss_total=0.0792  eval_loss_total=0.1644\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 32.92it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 185.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=9   train_loss_total=0.0670  eval_loss_total=0.1457\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 31.95it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 196.94it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=10  train_loss_total=0.1180  eval_loss_total=0.1758\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 32.77it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 170.32it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=11  train_loss_total=0.1103  eval_loss_total=0.1946\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 33.85it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 206.80it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=12  train_loss_total=0.0605  eval_loss_total=0.1850\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 33.00it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 190.91it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=13  train_loss_total=0.0523  eval_loss_total=0.1422\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 33.77it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 203.67it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=14  train_loss_total=0.0617  eval_loss_total=0.1399\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 32.98it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 171.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=15  train_loss_total=0.0472  eval_loss_total=0.1663\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 34.10it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 167.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=16  train_loss_total=0.0301  eval_loss_total=0.1447\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 32.94it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 201.30it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=17  train_loss_total=0.0235  eval_loss_total=0.1347\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 34.46it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 186.81it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=18  train_loss_total=0.0209  eval_loss_total=0.1373\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 32.72it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 193.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=19  train_loss_total=0.0210  eval_loss_total=0.1375\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 32.50it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 176.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=20  train_loss_total=0.0174  eval_loss_total=0.1315\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 33.07it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 163.10it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=21  train_loss_total=0.0162  eval_loss_total=0.1313\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 31.95it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 166.52it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=22  train_loss_total=0.0155  eval_loss_total=0.1289\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 31.23it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 168.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=23  train_loss_total=0.0151  eval_loss_total=0.1296\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 32.86it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 187.12it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=24  train_loss_total=0.0141  eval_loss_total=0.1364\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 33.85it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 189.82it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=25  train_loss_total=0.0146  eval_loss_total=0.1296\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 34.18it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 164.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=26  train_loss_total=0.0136  eval_loss_total=0.1305\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 32.74it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 203.46it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=27  train_loss_total=0.0128  eval_loss_total=0.1307\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 33.05it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 167.14it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=28  train_loss_total=0.0129  eval_loss_total=0.1301\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13/13 [00:00<00:00, 32.88it/s]\n",
      "100%|██████████| 4/4 [00:00<00:00, 179.64it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=29  train_loss_total=0.0124  eval_loss_total=0.1313\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "from mindnlp.core import value_and_grad\n",
    "\n",
    "def forward_fn(**batch):\n",
    "    outputs = model(batch[\"input_ids\"])\n",
    "    loss = criterion(outputs, batch[\"labels\"])\n",
    "    return loss\n",
    "\n",
    "grad_fn = value_and_grad(forward_fn, model.trainable_params())\n",
    "\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    model.set_train(True)\n",
    "    train_loss = 0\n",
    "    train_total_size = train_dataset.get_dataset_size()\n",
    "    for step, batch in enumerate(tqdm(train_dataset.create_dict_iterator(), total=train_total_size)):\n",
    "        optimizer.zero_grad()\n",
    "        loss = grad_fn(**batch)\n",
    "        optimizer.step()\n",
    "\n",
    "        train_loss += loss.float()\n",
    "        lr_scheduler.step()\n",
    "\n",
    "    model.set_train(False)\n",
    "    eval_loss = 0\n",
    "    eval_total_size = eval_dataset.get_dataset_size()\n",
    "    for step, batch in enumerate(tqdm(eval_dataset.create_dict_iterator(), total=eval_total_size)):\n",
    "        with ms._no_grad():\n",
    "            outputs = model(batch[\"input_ids\"])\n",
    "        loss = criterion(outputs, batch[\"labels\"])\n",
    "        eval_loss += loss.float()    \n",
    "\n",
    "    eval_loss_total = (eval_loss / len(eval_dataset)).item()\n",
    "    train_loss_total = (train_loss / len(train_dataset)).item()\n",
    "    print(f\"{epoch=:<2}  {train_loss_total=:.4f}  {eval_loss_total=:.4f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
